Optimize UDF with parallel execution #713
base: main
Conversation
Codecov Report

@@ Coverage Diff @@
##             main     #713      +/-   ##
==========================================
- Coverage   87.44%   87.21%   -0.24%
==========================================
  Files         114      116       +2
  Lines       10898    10963      +65
  Branches     1499     1508       +9
==========================================
+ Hits         9530     9561      +31
- Misses        990     1024      +34
  Partials      378      378
```diff
@@ -85,7 +85,6 @@ def run(
     udf_fields: "Sequence[str]",
     udf_inputs: "Iterable[RowsOutput]",
     catalog: "Catalog",
-    is_generator: bool,
```
Not used anywhere
```diff
 with contextlib.closing(
-    batching(warehouse.dataset_select_paginated, query)
+    batching(warehouse.db.execute, query, ids_only=True)
```
Not sure yet, but it looks like we don't need pagination here, since we are only selecting IDs. Should be tested at a bigger scale and confirmed.
Our `sys__id` is 8 bytes, so at 1B-row scale this will take 8 GB of memory by my calculation. I would still maybe leave it paginated.
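For reference, the arithmetic behind that estimate as a quick sketch (the 8-byte figure comes from the comment above; Python object overhead comes on top of the raw payload):

```python
# Raw payload for 1B ids at 8 bytes each.
n_ids = 1_000_000_000
bytes_per_id = 8  # sys__id is an 8-byte integer per the comment above
print(f"{n_ids * bytes_per_id / 2**30:.2f} GiB")  # ~7.45 GiB

# A Python list of int objects costs several times more: each small int
# object is ~28 bytes on 64-bit CPython, plus 8 bytes per list slot.
```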
```python
        n_workers=n_workers,
        processed_cb=processed_cb,
        download_cb=download_cb,
    )
    process_udf_outputs(warehouse, table, udf_results, udf, cb=generated_cb)
```
We are now processing results and inserting them into the DB in the parallel worker processes.
```python
    download_cb.relative_update(downloaded)
if processed := result.get("processed"):
    processed_cb.relative_update(processed)
if status in (OK_STATUS, NOTIFY_STATUS):
```
We are now doing these updates above for all types of signals, so there is no need to process these statuses here.
src/datachain/query/dispatch.py (outdated)
```python
process_udf_outputs(
    warehouse,
    self.table,
    self.notify_and_process(udf_results, processed_cb),
    self.udf,
    cb=processed_cb,
)
warehouse.insert_rows_done(self.table)

put_into_queue(
    self.done_queue,
    {"status": FINISHED_STATUS, "processed": processed_cb.processed_rows},
)
```
Do not pass results back to the main process; write them straight into the DB here.
```diff
 def notify_and_process(self, udf_results, processed_cb):
     for row in udf_results:
         put_into_queue(
             self.done_queue,
-            {"status": NOTIFY_STATUS, "processed": processed_cb.processed_rows},
+            {"status": OK_STATUS, "processed": processed_cb.processed_rows},
         )
         yield row
     put_into_queue(self.done_queue, {"status": FINISHED_STATUS})
```
Helper function to notify the main process before writing results into the DB.
Is it possible to treat 1 parallel process as no parallel processes, or to raise an error when only one parallel process is specified?
@dreadatour, while you are working on this, could you please also take a look at this example test:

Does this PR improve that test? Should it take that long?
Sure, sounds reasonable 👍
Nice catch, let me take a look 🙏
Found an issue. This is because of this (basically everything runs in a single process, because the batch size is 10k and the number of records is only 400). A couple of tests:
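To spell out the arithmetic behind that (illustrative only, using the numbers from the comment above):

```python
# 400 rows with a 10_000-row batch size yields a single batch,
# so only one worker process ever receives work.
import math

n_rows, batch_size = 400, 10_000
print(math.ceil(n_rows / batch_size))  # 1 batch -> 1 busy worker
```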
Fixed: https://github.com/iterative/datachain/actions/runs/12372015291/job/34529325219?pr=713 (409 sec)
Good work! I've added a couple of small comments and a few thoughts / questions.
```python
current_partition: Optional[int] = None
batch: list[Sequence] = []

query_fields = [str(c.name) for c in query.selected_columns]
# query_fields = [column_name(col) for col in query.inner_columns]
```
Commented-out code.
```diff
@@ -464,8 +465,8 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:

     with subprocess.Popen(cmd, env=envs, stdin=subprocess.PIPE) as process:  # noqa: S603
         process.communicate(process_data)
-        if process.poll():
-            raise RuntimeError("UDF Execution Failed!")
+        if ret := process.poll():
```
I would maybe use the full variable name, as `ret` is not so common a shortcut IMO.
`retval` maybe? 👀
"Identifiers that exist for short scopes should be short." It is consumed in the next line.
So, this is okay.
Alternatively, it could be renamed `retval` (as you have proposed), `exitcode`, `retcode`, etc. But it's not necessary IMO.
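For reference, a minimal sketch of what the walrus check in the diff does; the failing child command here is hypothetical, and `subprocess.run` stands in for the `Popen`/`poll()` pair in the PR:

```python
import subprocess
import sys

# Child process that exits with a nonzero code to exercise the failure branch.
proc = subprocess.run([sys.executable, "-c", "import sys; sys.exit(3)"])

# returncode is 0 on success and nonzero on failure; only the nonzero
# (truthy) case raises, mirroring `if ret := process.poll():` above.
if ret := proc.returncode:
    raise RuntimeError(f"UDF Execution Failed! Exit code: {ret}")
```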
```python
if self.is_batching:
    while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
        ids = [row[0] for row in batch.rows]
        rows = warehouse.dataset_rows_select(self.query.where(col_id.in_(ids)))
```
One thing that I'm worried about is this query in ClickHouse. There we have rows sorted and "packed" into granules by `sys__id`, which is the primary key. It would be ideal if these ID batches were all sorted and "close" to each other; otherwise we could end up in a situation where a big chunk of the DB is read for every batch, just because one ID ended up in the first granule/part, another in the second one, etc. This is because ClickHouse reads the whole part/granule even if we need only one record from it.
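A possible mitigation along those lines, as a sketch only (not part of this PR; it reuses `batch`, `warehouse`, `col_id`, and `self.query` from the diff above): sort each ID batch before querying, so consecutive IDs hit adjacent granules.

```python
# Sorting ids keeps the IN() lookup clustered, so ClickHouse can serve the
# batch from a few adjacent granules instead of scattered parts.
ids = sorted(row[0] for row in batch.rows)
rows = warehouse.dataset_rows_select(self.query.where(col_id.in_(ids)))
```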
Highlights

In this PR I am:

This will prevent multiple type conversions.

Before it was:
- send rows into a `multiprocess.Queue` -> convert rows from Python types with msgpack
- get rows from the `multiprocess.Queue` -> convert rows to Python types (msgpack)
- send results into a `multiprocess.Queue` -> convert from Python types (msgpack)
- get results from the `multiprocess.Queue` -> convert them back into Python types (msgpack)

After:
- converting `int`s into Python `int`s is quick, stable, predictable
- `multiprocess.Queue` -> convert the list of IDs with pickle (by default)

In the end:
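To make the before/after difference concrete, here is a hypothetical micro-benchmark (not from the PR; the row shape and sizes are invented, and it assumes the third-party `msgpack` package is installed) comparing a full-row msgpack round trip with a pickle round trip of bare IDs:

```python
import pickle
import time

import msgpack  # pip install msgpack

# Invented row shape: (id, path, payload) as a stand-in for warehouse rows.
rows = [(i, f"file_{i}.jpg", b"x" * 64) for i in range(100_000)]
ids = [row[0] for row in rows]

start = time.perf_counter()
msgpack.unpackb(msgpack.packb(rows))  # old path: full rows, both directions
t_rows = time.perf_counter() - start

start = time.perf_counter()
pickle.loads(pickle.dumps(ids))       # new path: only IDs, both directions
t_ids = time.perf_counter() - start

print(f"full rows via msgpack: {t_rows:.3f}s, ids via pickle: {t_ids:.3f}s")
```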
Test scenario

Simple script to check the raw `parallel` setting only (a sketch of such a script is given below). This is a very simple and basic scenario, but it helps us test the `parallel` setting alone, without any other overheads.

Note: `prefetch` is off in this case, to measure `parallel` only.
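The original script is collapsed in the thread; the following is only a sketch of what such a benchmark could look like, assuming the DataChain API of that era (`from_values`, `settings(parallel=..., prefetch=...)`, `map`, `exec`) — check the names against the version you run:

```python
import time

from datachain import DataChain


def heavy_udf(num: int) -> int:
    # CPU-bound stand-in for a real UDF.
    total = 0
    for i in range(10_000):
        total += (num * i) % 7
    return total


start = time.perf_counter()
(
    DataChain.from_values(num=list(range(100_000)))
    .settings(parallel=8, prefetch=0)  # prefetch off: measure `parallel` only
    .map(result=heavy_udf)
    .exec()
)
print(f"processed in {time.perf_counter() - start:.1f}s")
```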
Overview

On the chart below there are two series: before the optimization (blue) and after (green). The X axis is the number of parallel processes; the Y axis is the total number of rows processed by the UDF in parallel. This is for the SQLite warehouse on my local machine.
As we can see, in the "before" case the `parallel` option does not affect performance at all: there is a hard cap on throughput, and it does not depend on the number of parallel processes.

The reason is that we pass rows to the UDF in each parallel process via a `multiprocess.Queue` and get results back the same way, and `Queue` throughput is very limited. I wrote a simple script to test the `Queue` alone, and it is indeed the bottleneck (a sketch of such a test follows below). I have tried different IPC mechanisms (pipes, ZeroMQ) and they all hit a similar limit. This could be solved by introducing external dependencies (Redis, RabbitMQ, etc.), but that is not what we want for a CLI tool.

"After" performs much better: the `Queue` is now used only to pass IDs in batches, and it is performant enough that throughput grows with the number of parallel processes. The growth is not linear, because the SQLite warehouse is now the limit, but it is much better, stable, and predictable.

Also note that "1 parallel process" is ~2.15 times slower than plain "no parallel processes"; this is essentially the overhead of using parallel processes and queues to read and pass IDs. With 2-3 parallel processes performance matches "no parallel", and it keeps increasing as the number of parallel processes grows.
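The `Queue`-only test script is not included in the thread; here is a minimal stdlib-only sketch of such a test, with invented payload and sizes:

```python
import time
from multiprocessing import Process, Queue


def consumer(q):
    # Drain the queue until the sentinel arrives.
    while q.get() is not None:
        pass


if __name__ == "__main__":
    q = Queue(maxsize=10_000)
    proc = Process(target=consumer, args=(q,))
    proc.start()

    n_items = 1_000_000
    row = (1, "file.jpg", b"x" * 64)  # invented row-like payload

    start = time.perf_counter()
    for _ in range(n_items):
        q.put(row)
    q.put(None)  # sentinel: tell the consumer to stop
    proc.join()

    elapsed = time.perf_counter() - start
    print(f"{n_items / elapsed:,.0f} items/sec through multiprocessing.Queue")
```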
Next, I am going to measure the same numbers on a ClickHouse warehouse; I expect it to be much better and closer to linear.
More measurements for those who love raw numbers
Before
Not parallel (for reference)
Parallel = 1 (edge case)
Parallel = 8
After
Not parallel (for reference)
Parallel = 1 (edge case)
Parallel = 8